热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

DGLRDKit|基于AttentiveFP可视化训练模型原子权重

DGL具有许多用于化学信息学、药物与生物信息学任务的函数。DGL开发人员提供了用于可视化训练模型原子权重的代码。使用AttentiveFP构建模型后,可以可视化给定

DGL具有许多用于化学信息学、药物与生物信息学任务的函数。

DGL开发人员提供了用于可视化训练模型原子权重的代码。使用Attentive FP构建模型后,可以可视化给定分子的原子权重,意味着每个原子对目标值的贡献量。




基于Attentive FP可视化训练模型原子权重


环境准备


  • PyTorch:深度学习框架
  • DGL:基于PyTorch的库,支持深度学习以处理图形
  • RDKit:用于构建分子图并从字符串表示形式绘制结构式
  • MDTraj:用于分子动力学轨迹分析的开源库



导入库

%matplotlib inline
import matplotlib.pyplot as plt
import os
from rdkit import Chem
from rdkit import RDPathsimport dgl
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from dgl import model_zoofrom dgl.data.chem.utils import mol_to_complete_graph, mol_to_bigraphfrom dgl.data.chem.utils import atom_type_one_hot
from dgl.data.chem.utils import atom_degree_one_hot
from dgl.data.chem.utils import atom_formal_charge
from dgl.data.chem.utils import atom_num_radical_electrons
from dgl.data.chem.utils import atom_hybridization_one_hot
from dgl.data.chem.utils import atom_total_num_H_one_hot
from dgl.data.chem.utils import one_hot_encoding
from dgl.data.chem import CanonicalAtomFeaturizer
from dgl.data.chem import CanonicalBondFeaturizer
from dgl.data.chem import ConcatFeaturizer
from dgl.data.chem import BaseAtomFeaturizer
from dgl.data.chem import BaseBondFeaturizerfrom dgl.data.chem import one_hot_encoding
from dgl.data.utils import split_datasetfrom functools import partial
from sklearn.metrics import roc_auc_score

代码来源于dgl/example

DGL开发人员提供了用于可视化训练模型原子权重的代码。

使用Attentive FP构建模型后,可以可视化给定分子的原子权重,意味着每个原子对目标值的贡献量。

 

def chirality(atom):try:return one_hot_encoding(atom.GetProp('_CIPCode'), ['R', 'S']) + \[atom.HasProp('_ChiralityPossible')]except:return [False, False] + [atom.HasProp('_ChiralityPossible')]def collate_molgraphs(data):"""Batching a list of datapoints for dataloader.Parameters----------data : list of 3-tuples or 4-tuples.Each tuple is for a single datapoint, consisting ofa SMILES, a DGLGraph, all-task labels and optionallya binary mask indicating the existence of labels.Returns-------smiles : listList of smilesbg : BatchedDGLGraphBatched DGLGraphslabels : Tensor of dtype float32 and shape (B, T)Batched datapoint labels. B is len(data) andT is the number of total tasks.masks : Tensor of dtype float32 and shape (B, T)Batched datapoint binary mask, indicating theexistence of labels. If binary masks are notprovided, return a tensor with ones."""assert len(data[0]) in [3, 4], \'Expect the tuple to be of length 3 or 4, got {:d}'.format(len(data[0]))if len(data[0]) == 3:smiles, graphs, labels = map(list, zip(*data))masks = Noneelse:smiles, graphs, labels, masks = map(list, zip(*data))bg = dgl.batch(graphs)bg.set_n_initializer(dgl.init.zero_initializer)bg.set_e_initializer(dgl.init.zero_initializer)labels = torch.stack(labels, dim=0)if masks is None:masks = torch.ones(labels.shape)else:masks = torch.stack(masks, dim=0)return smiles, bg, labels, masksatom_featurizer = BaseAtomFeaturizer({'hv': ConcatFeaturizer([partial(atom_type_one_hot, allowable_set=['B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br', 'Te', 'I', 'At'],encode_unknown=True),partial(atom_degree_one_hot, allowable_set=list(range(6))),atom_formal_charge, atom_num_radical_electrons,partial(atom_hybridization_one_hot, encode_unknown=True),lambda atom: [0], # A placeholder for aromatic information,atom_total_num_H_one_hot, chirality],)})
bond_featurizer = BaseBondFeaturizer({'he': lambda bond: [0 for _ in range(10)]})train_mols = Chem.SDMolSupplier('solubility.train.sdf')
train_smi =[Chem.MolToSmiles(m) for m in train_mols]
train_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in train_mols]).reshape(-1,1)test_mols = Chem.SDMolSupplier('solubility.test.sdf')
test_smi = [Chem.MolToSmiles(m) for m in test_mols]
test_sol = torch.tensor([float(mol.GetProp('SOL')) for mol in test_mols]).reshape(-1,1)train_graph =[mol_to_bigraph(mol,node_featurizer=atom_featurizer, edge_featurizer=bond_featurizer) for mol in train_mols]test_graph =[mol_to_bigraph(mol,node_featurizer=atom_featurizer, edge_featurizer=bond_featurizer) for mol in test_mols]def run_a_train_epoch(n_epochs, epoch, model, data_loader,loss_criterion, optimizer):model.train()total_loss = 0losses = []for batch_id, batch_data in enumerate(data_loader):batch_datasmiles, bg, labels, masks = batch_dataif torch.cuda.is_available():bg.to(torch.device('cuda:0'))labels = labels.to('cuda:0')masks = masks.to('cuda:0')prediction = model(bg, bg.ndata['hv'], bg.edata['he'])loss = (loss_criterion(prediction, labels)*(masks != 0).float()).mean()#loss = loss_criterion(prediction, labels)#print(loss.shape)optimizer.zero_grad()loss.backward()optimizer.step()losses.append(loss.data.item())#total_score = np.mean(train_meter.compute_metric('rmse'))total_score = np.mean(losses)print('epoch {:d}/{:d}, training {:.4f}'.format( epoch + 1, n_epochs, total_score))return total_scoremodel = model_zoo.chem.AttentiveFP(node_feat_size=39,edge_feat_size=10,num_layers=2,num_timesteps=2,graph_feat_size=200,output_size=1,dropout=0.2)train_loader = DataLoader(dataset=list(zip(train_smi, train_graph, train_sol)), batch_size=128, collate_fn=collate_molgraphs)
test_loader = DataLoader(dataset=list(zip(test_smi, test_graph, test_sol)), batch_size=128, collate_fn=collate_molgraphs)loss_fn = nn.MSELoss(reduction='none')
optimizer = torch.optim.Adam(model.parameters(), lr=10 ** (-2.5), weight_decay=10 ** (-5.0),)
n_epochs = 100
epochs = []
scores = []
for e in range(n_epochs):score = run_a_train_epoch(n_epochs, e, model, train_loader, loss_fn, optimizer)epochs.append(e)scores.append(score)
model.eval()

导入用于分子可视化依赖库

import copy
from rdkit.Chem import rdDepictor
from rdkit.Chem.Draw import rdMolDraw2D
from IPython.display import SVG
from IPython.display import display
import matplotlib
import matplotlib.cm as cm

定义可视化函数


  • 代码来源于DGL库。
  • DGL模型具有get_node_weight选项,该选项返回图形的node_weight。该模型具有两层GRU,因此以下代码我将0用作时间步长,因此时间步长必须为0或1。

def drawmol(idx, dataset, timestep):smiles, graph, _ = dataset[idx]print(smiles)bg = dgl.batch([graph])atom_feats, bond_feats = bg.ndata['hv'], bg.edata['he']if torch.cuda.is_available():print('use cuda')bg.to(torch.device('cuda:0'))atom_feats = atom_feats.to('cuda:0')bond_feats = bond_feats.to('cuda:0')_, atom_weights = model(bg, atom_feats, bond_feats, get_node_weight=True)assert timestep

绘制测试数据集分子

该模型预测溶解度,颜色表示红色是溶解度的积极影响,蓝色是负面影响。

target = test_loader.dataset
for i in range(len(target)):mol, aw, svg = drawmol(i, target, 0)display(SVG(svg))

。。。。。 




参考资料

1. https://github.com/dmlc/dgl/tree/master/apps/life_sci

2. https://github.com/dmlc/dgl/blob/master/python/dgl/model_zoo/chem/attentive_fp.py

3. https://pubs.acs.org/doi/full/10.1021/acs.jcim.9b00387

 


推荐阅读
  • 本文探讨了如何使用Scrapy框架构建高效的数据采集系统,以及如何通过异步处理技术提升数据存储的效率。同时,文章还介绍了针对不同网站采用的不同采集策略。 ... [详细]
  • 我在尝试将组合框转换为具有自动完成功能时遇到了一个问题,即页面上的列表框也被转换成了自动完成下拉框,而不是保持原有的多选列表框形式。 ... [详细]
  • This article explores the process of integrating Promises into Ext Ajax calls for a more functional programming approach, along with detailed steps on testing these asynchronous operations. ... [详细]
  • 探索CNN的可视化技术
    神经网络的可视化在理论学习与实践应用中扮演着至关重要的角色。本文深入探讨了三种有效的CNN(卷积神经网络)可视化方法,旨在帮助读者更好地理解和优化模型。 ... [详细]
  • 基于SSM框架的在线考试系统:随机组卷功能详解
    本文深入探讨了基于SSM(Spring, Spring MVC, MyBatis)框架构建的在线考试系统中,随机组卷功能的设计与实现方法。 ... [详细]
  • 在Android中实现黑客帝国风格的数字雨效果
    本文将详细介绍如何在Android平台上利用自定义View实现类似《黑客帝国》中的数字雨效果。通过实例代码,我们将探讨如何设置文字颜色、大小,以及如何控制数字下落的速度和间隔。 ... [详细]
  • ZOJ 2760 - 最大流问题
    题目链接:How Many Shortest Paths。题目描述:给定一个包含n个节点的有向图,通过一个n*n的矩阵来表示。矩阵中的a[i][j]值为-1表示从节点i到节点j无直接路径;否则,该值表示从i到j的路径长度。输入起点vs和终点vt,计算从vs到vt的所有不共享任何边的最短路径数量。如果起点和终点相同,则输出无穷大。 ... [详细]
  • A1166 峰会区域安排问题(25分)PAT甲级 C++满分解析【图论】
    峰会是指国家元首或政府首脑之间的会议。合理安排峰会的休息区是一项复杂的工作,理想的情况是邀请的每位领导人都是彼此的直接朋友。 ... [详细]
  • 本文介绍了进程的基本概念及其在操作系统中的重要性,探讨了进程与程序的区别,以及如何通过多进程实现并发和并行。文章还详细讲解了Python中的multiprocessing模块,包括Process类的使用方法、进程间的同步与异步调用、阻塞与非阻塞操作,并通过实例演示了进程池的应用。 ... [详细]
  • ED Tree HDU4812 点分治+逆元
    这道题非常巧妙!!!我们进行点分治的时候,算出当前子节点的所有子树中的节点,到当前节点节点的儿子节点的距离,如下图意思就是当前节点的红色节点,我们要求出红色节点的儿子节点绿色节点, ... [详细]
  • 前端技术分享——利用Canvas绘制鼠标轨迹
    作为一名前端开发者,我已经积累了Vue、React、正则表达式、算法以及小程序等方面的技能,但Canvas一直是我的盲区。因此,我在2018年为自己设定了一个新的学习目标:掌握Canvas,特别是如何使用它来创建CSS3难以实现的动态效果。 ... [详细]
  • 本文探讨了在 PHP 的 Zend 框架下,使用 PHPUnit 进行单元测试时遇到的 Zend_Controller_Response_Exception 错误,并提供了解决方案。 ... [详细]
  • Exploring issues and solutions when defining multiple Faust agents programmatically. ... [详细]
  • 本文档旨在提供C语言的基础知识概述,涵盖常量、变量、数据类型、控制结构及函数定义等内容。特别强调了常量的不同类型及其在程序中的应用,以及如何正确声明和使用函数。 ... [详细]
  • 在AngularJS中,有时需要在表单内包含某些控件,但又不希望这些控件导致表单变为脏状态。例如,当用户对表单进行修改后,表单的$dirty属性将变为true,触发保存对话框。然而,对于一些导航或辅助功能控件,我们可能并不希望它们触发这种行为。 ... [详细]
author-avatar
手机用户2502892403
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有